{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "abc_mnist.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A-JRAXTrrgBG"
      },
      "source": [
        "Copyright 2021 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "      https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MsunizoxDgdD"
      },
      "source": [
        "# Approximate Bijective Correspondence (ABC) with MNIST\n",
        "\n",
        "ABC seeks correspondence between input sets of data which have been grouped by inactive factor of variation.  In the case of MNIST, the data has been grouped by digit class, leaving style as the active factor of variation to embed."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NMJlY_CHfqFC",
        "cellView": "form"
      },
      "source": [
        "#@title Imports\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline\n",
        "import tensorflow_datasets as tfds\n",
        "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
        "\n",
        "from sklearn.decomposition import PCA\n",
        "\n",
        "tfkl = tf.keras.layers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5HJZITJef4Of"
      },
      "source": [
        "# The number of images of a digit in each of the two sets\n",
        "# This impacts the level of detail to which the network is sensitive.  You could\n",
        "# imagine finding correspondence between stacks of 4 would require coarser\n",
        "# details than stacks of 64.\n",
        "stack_size = 64\n",
        "\n",
        "# The dimension of the embedding space\n",
        "num_latent_dims = 8\n",
        "\n",
        "# The similarity type to use as the distance metric in embedding space\n",
        "# Available options:\n",
        "# l2 : Negative Euclidean distance\n",
        "# l2sq : Negative squared Euclidean distance\n",
        "# l1 : Neggative L1 distance ('Manhattan' distance)\n",
        "# linf : ord = inf distance (the max displacement along any coordinate), negated\n",
        "# cosine : cosine similarity, bounded between -1 and 1\n",
        "similarity_type = 'l2sq'\n",
        "\n",
        "# The digit to withold during training\n",
        "test_digit = 9\n",
        "\n",
        "optimizer_name = 'adam'\n",
        "lr = 1e-4\n",
        "num_steps = 500\n",
        "temperature = 1.  # essentially ineffective unless using cosine similarity, just sets the length scale in embedding space\n",
        "imgs_to_plot = 400  # for displaying the embeddings via pca\n",
        "output_plots_during_training = True\n",
        "output_loss_every = 100"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yBt_sgksmyrs",
        "cellView": "form"
      },
      "source": [
        "#@title Load MNIST into 10 digit-specific datasets\n",
        "dset = tfds.load('mnist', split='train')\n",
        "dset = dset.map(lambda example: (tf.cast(example['image'], tf.float32)/255., example['label']),\n",
        "                num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "ds = [dset.filter(lambda x, y: y==i) for i in range(10)]\n",
        "ds = [d.map(lambda x, y: x).shuffle(1000).repeat().batch(stack_size) for d in ds]\n",
        "\n",
        "dset_test = tfds.load('mnist', split='test')\n",
        "dset_test = dset_test.map(lambda example: (tf.cast(example['image'], tf.float32)/255., example['label']),\n",
        "                num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "ds_test = [dset_test.filter(lambda x, y: y==i) for i in range(10)]\n",
        "ds_test = [d.map(lambda x, y: x).batch(stack_size) for d in ds_test]\n",
        "\n",
        "# Combine stacks from different digits randomly (sometimes a digit is paired with itself but this does not derail training).\n",
        "# The shape of each element is [2, stack_size, 28, 28, 1].\n",
        "combined_dset = tf.data.experimental.sample_from_datasets(ds[:test_digit]+ds[test_digit+1:]).batch(2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WFQrTGtugO-C",
        "cellView": "form"
      },
      "source": [
        "#@title The data is grouped by class label; this is all the supervision needed to learn about writing style.\n",
        "for d in ds:\n",
        "  for img_stack in d.take(1):\n",
        "    plt.figure(figsize=(9, 1))\n",
        "    for j in range(8):\n",
        "      plt.subplot(1, 8, j+1)\n",
        "      plt.imshow(img_stack[j, ..., 0], cmap='binary')\n",
        "      plt.xticks([]); plt.yticks([])\n",
        "    plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HF5975h1gdat"
      },
      "source": [
        "# The embedding model\n",
        "model = tf.keras.Sequential([\n",
        "                             tfkl.Input(shape=(28, 28, 1)),\n",
        "                             tfkl.Conv2D(32, 3, activation='relu'),\n",
        "                             tfkl.Conv2D(32, 3, activation='relu'),\n",
        "                             tfkl.Conv2D(32, 3, activation='relu', strides=2),\n",
        "                             tfkl.Conv2D(32, 3, activation='relu'),\n",
        "                             tfkl.Conv2D(32, 3, activation='relu'),\n",
        "                             tfkl.Flatten(),\n",
        "                             tfkl.Dense(128, activation='relu'),\n",
        "                             tfkl.Dense(num_latent_dims, activation='linear'),\n",
        "                             ])\n",
        "print(model.summary())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cv1B3-bcrNo9",
        "cellView": "form"
      },
      "source": [
        "#@title Run PCA on the output of the untrained model (for comparison).\n",
        "plot_imgs = []\n",
        "embs_pre = []\n",
        "for img_stack in ds[test_digit].take(imgs_to_plot//stack_size):\n",
        "  plot_imgs.append(img_stack)\n",
        "  embs_pre.append(model(img_stack, training=False))\n",
        "plot_imgs = tf.concat(plot_imgs, 0)\n",
        "embs_pre = tf.concat(embs_pre, 0)\n",
        "\n",
        "pca_pre = PCA(n_components=2)\n",
        "pca_pre.fit(embs_pre)\n",
        "print('PCA2 explained variance before training:', pca_pre.explained_variance_ratio_)\n",
        "t_pre = pca_pre.transform(embs_pre)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hZ7rwV_zjPxF",
        "cellView": "form"
      },
      "source": [
        "#@title The heart of ABC: helper functions for computing the loss\n",
        "# Many were copied+modified from Dwibedi et al. (2019).\n",
        "@tf.function\n",
        "def pairwise_l2_distance(embs1, embs2):\n",
        "  # embs are shape [stack_size, num_latent_dims]\n",
        "  # returns shape [stack_size, stack_size] as the full matrix of distances btwn embs\n",
        "  norm1 = tf.reduce_sum(tf.square(embs1), 1)\n",
        "  norm1 = tf.reshape(norm1, [-1, 1])\n",
        "  norm2 = tf.reduce_sum(tf.square(embs2), 1)\n",
        "  norm2 = tf.reshape(norm2, [1, -1])\n",
        "  dist = tf.maximum(\n",
        "      norm1 + norm2 - 2.0 * tf.matmul(embs1, embs2, False, True), 0.0)\n",
        "  return dist\n",
        "\n",
        "@tf.function\n",
        "def pairwise_l1_distance(embs1, embs2):\n",
        "  ss2 = embs2.shape[0]\n",
        "  embs1_tiled = tf.tile(tf.expand_dims(embs1, 1), [1, ss2, 1])\n",
        "  dist = tf.reduce_sum(tf.abs(embs1_tiled-embs2), -1)\n",
        "  return dist\n",
        "\n",
        "@tf.function\n",
        "def pairwise_linf_distance(embs1, embs2):\n",
        "  ss2 = embs2.shape[0]\n",
        "  embs1_tiled = tf.tile(tf.expand_dims(embs1, 1), [1, ss2, 1])\n",
        "  dist = tf.reduce_max(tf.abs(embs1_tiled-embs2), -1)\n",
        "  return dist\n",
        "\n",
        "@tf.function\n",
        "def get_scaled_similarity(embs1, embs2, similarity_type, temperature):\n",
        "  if similarity_type == 'l2sq':\n",
        "    similarity = -1.0 * pairwise_l2_distance(embs1, embs2)\n",
        "  elif similarity_type == 'l2':\n",
        "    similarity = -1.0 * tf.sqrt(pairwise_l2_distance(embs1, embs2) + eps)\n",
        "  elif similarity_type == 'l1':\n",
        "    similarity = -1.0 * pairwise_l1_distance(embs1, embs2)\n",
        "  elif similarity_type == 'linf':\n",
        "    similarity = -1.0 * pairwise_linf_distance(embs1, embs2)\n",
        "  elif similarity_type == 'cosine':\n",
        "    embs1, _ = tf.linalg.normalize(embs1, ord=2, axis=-1)\n",
        "    embs2, _ = tf.linalg.normalize(embs2, ord=2, axis=-1)\n",
        "    similarity = tf.matmul(embs1, embs2, transpose_b=True)\n",
        "  else:\n",
        "    raise ValueError('Unknown similarity type: {}'.format(similarity_type))\n",
        "\n",
        "  similarity /= temperature\n",
        "  return similarity\n",
        "\n",
        "@tf.function\n",
        "def align_pair_of_sequences(embs1, embs2, similarity_type, temperature):\n",
        "  # Creates a soft nearest neighbor for each emb1 out of the elements of embs2\n",
        "  ss1 = tf.shape(embs1)[0]\n",
        "  sim_12 = get_scaled_similarity(embs1, embs2, similarity_type, temperature)\n",
        "  softmaxed_sim_12 = tf.nn.softmax(sim_12, axis=1)\n",
        "  nn_embs = tf.matmul(softmaxed_sim_12, embs2)\n",
        "  sim_21 = get_scaled_similarity(nn_embs, embs1, similarity_type, temperature)\n",
        "\n",
        "  loss = tf.keras.losses.sparse_categorical_crossentropy(tf.range(ss1), sim_21,\n",
        "                                                         from_logits=True)\n",
        "\n",
        "  return tf.reduce_mean(loss)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xftz7vA1ZGUx",
        "cellView": "form"
      },
      "source": [
        "#@title display_pca_embs definition\n",
        "def display_pca_embs(model, plot_imgs, step, avg_loss):\n",
        "  embs_intermed = []\n",
        "  for start_ind in range(0, len(plot_imgs), stack_size):\n",
        "    embs_intermed.append(model(plot_imgs[start_ind:start_ind+stack_size], training=False))\n",
        "  embs_intermed = tf.concat(embs_intermed, 0)\n",
        "\n",
        "  pca_intermed = PCA(n_components=2)\n",
        "  pca_intermed.fit(embs_intermed)\n",
        "  t_intermed = pca_intermed.transform(embs_intermed)\n",
        "  plt.figure(figsize=(8, 8))\n",
        "  zoom_factor = 1.  # scales the size of the individual digit images\n",
        "\n",
        "  ax = plt.gca()\n",
        "  for img_id, img in enumerate(plot_imgs):\n",
        "    img = tf.concat([1-img, 1-img, 1-img, img], -1)\n",
        "    im = OffsetImage(img, zoom=zoom_factor)\n",
        "    ab = AnnotationBbox(im, t_intermed[img_id], frameon=False)\n",
        "    ax.add_artist(ab)\n",
        "    plt.scatter(t_intermed[img_id, 0], t_intermed[img_id, 1], s=0)  # this is just so the axes bound the images\n",
        "  plt.xlabel('PC0, var {:.3f}'.format(pca_intermed.explained_variance_ratio_[0]), fontsize=14.)\n",
        "  plt.ylabel('PC1, var {:.3f}'.format(pca_intermed.explained_variance_ratio_[1]), fontsize=14.)\n",
        "  plt.title('Step {}, ABC Loss = {:.3f}'.format(step, avg_loss), fontsize=16.)\n",
        "\n",
        "  plt.show()\n",
        "  return"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tbJNU9zghzPR"
      },
      "source": [
        "# Train the embedder\n",
        "opt = tf.keras.optimizers.get(optimizer_name)\n",
        "opt.lr = lr\n",
        "\n",
        "losses = []\n",
        "\n",
        "for step, paired_stacks in enumerate(combined_dset.take(num_steps)):\n",
        "  with tf.GradientTape() as tape:\n",
        "    embs1 = model(paired_stacks[0], training=True)\n",
        "    embs2 = model(paired_stacks[1], training=True)\n",
        "    loss = align_pair_of_sequences(embs1, embs2, similarity_type, temperature)\n",
        "    loss += align_pair_of_sequences(embs2, embs1, similarity_type, temperature)\n",
        "  grads = tape.gradient(loss, model.trainable_variables)\n",
        "  opt.apply_gradients(zip(grads, model.trainable_variables))\n",
        "  losses.append(loss.numpy())\n",
        "\n",
        "  if not step % output_loss_every:\n",
        "\n",
        "    if output_plots_during_training:\n",
        "      display_pca_embs(model, plot_imgs, step, np.average(losses[-output_loss_every:]))\n",
        "    else:\n",
        "      print('Step {} Loss: {:.2f}'.format(step,\n",
        "                                          np.average(losses[-output_loss_every:])))\n",
        "if output_plots_during_training:\n",
        "  display_pca_embs(model, plot_imgs, step, np.average(losses[-output_loss_every:]))\n",
        "else:\n",
        "  print('Step {} Loss: {:.2f}'.format(step,\n",
        "                                      np.average(losses[-output_loss_every:])))\n",
        "print('Training completed.')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mxs5vvDzrtX_",
        "cellView": "form"
      },
      "source": [
        "#@title Run PCA on the output of the trained model\n",
        "embs_post = []\n",
        "for start_ind in range(0, len(plot_imgs), stack_size):\n",
        "  if not start_ind:\n",
        "    embs_post = model(plot_imgs[start_ind:start_ind+stack_size], training=False)\n",
        "  else:\n",
        "    embs_post = tf.concat([embs_post, model(plot_imgs[start_ind:start_ind+stack_size], training=False)], 0)\n",
        "\n",
        "pca_post = PCA(n_components=2)\n",
        "pca_post.fit(embs_post)\n",
        "print('PCA2 explained variance after training:', pca_post.explained_variance_ratio_)\n",
        "t_post = pca_post.transform(embs_post)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HK_JaSvcPORG",
        "cellView": "form"
      },
      "source": [
        "#@title Compare the PCA embeddings before and after training.\n",
        "plt.figure(figsize=(16, 8))\n",
        "zoom_factor = 1.  # scales the size of the individual digit images\n",
        "\n",
        "plt.subplot(121)\n",
        "ax = plt.gca()\n",
        "for img_id, img in enumerate(plot_imgs):\n",
        "  img = tf.concat([1-img, 1-img, 1-img, img], -1)\n",
        "  im = OffsetImage(img, zoom=zoom_factor)\n",
        "  ab = AnnotationBbox(im, t_pre[img_id], frameon=False)\n",
        "  ax.add_artist(ab)\n",
        "  plt.scatter(t_pre[img_id, 0], t_pre[img_id, 1], s=0)  # this is just so the axes bound the images\n",
        "plt.xlabel('PC0, var = {:.3f}'.format(pca_pre.explained_variance_ratio_[0]), fontsize=14.)\n",
        "plt.ylabel('PC1, var = {:.3f}'.format(pca_pre.explained_variance_ratio_[1]), fontsize=14.)\n",
        "plt.title('Before training', fontsize=16.)\n",
        "\n",
        "plt.subplot(122)\n",
        "ax = plt.gca()\n",
        "for img_id, img in enumerate(plot_imgs):\n",
        "  img = tf.concat([1-img, 1-img, 1-img, img], -1)\n",
        "  im = OffsetImage(img, zoom=zoom_factor)\n",
        "  ab = AnnotationBbox(im, t_post[img_id], frameon=False)\n",
        "  ax.add_artist(ab)\n",
        "  plt.scatter(t_post[img_id, 0], t_post[img_id, 1], s=0)\n",
        "plt.xlabel('PC0, var = {:.3f}'.format(pca_post.explained_variance_ratio_[0]), fontsize=14.)\n",
        "plt.ylabel('PC1, var = {:.3f}'.format(pca_post.explained_variance_ratio_[1]), fontsize=14.)\n",
        "plt.title('After training', fontsize=16.)\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b2z2A3ppZ0_i",
        "cellView": "form"
      },
      "source": [
        "#@title Check out other digits (0s and 1s are often easier to decipher).\n",
        "imgs_to_plot = 200\n",
        "digits_to_plot = [0, 1, 2]\n",
        "plt.figure(figsize=(18, 6))\n",
        "for plot_id, digit_id in enumerate(digits_to_plot):\n",
        "  imgs = []\n",
        "  embs = []\n",
        "  for ind, img_stack in enumerate(ds[digit_id].take(imgs_to_plot//stack_size)):\n",
        "    if not ind:\n",
        "      imgs = img_stack\n",
        "      embs = model(img_stack, training=False)\n",
        "    else:\n",
        "      imgs = tf.concat([imgs, img_stack], 0)\n",
        "      embs = tf.concat([embs, model(img_stack, training=False)], 0)\n",
        "  t_post = pca_post.transform(embs)\n",
        "  plt.subplot(1, 3, plot_id+1)\n",
        "  ax = plt.gca()\n",
        "  for img_id, img in enumerate(imgs):\n",
        "    img = tf.concat([1-img, 1-img, 1-img, img], -1)\n",
        "    im = OffsetImage(img, zoom=zoom_factor)\n",
        "    ab = AnnotationBbox(im, t_post[img_id], frameon=False)\n",
        "    ax.add_artist(ab)\n",
        "    plt.scatter(t_post[img_id, 0], t_post[img_id, 1], s=0.)\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aqKpBc0cxClz"
      },
      "source": [
        "# Perform retrieval using random test digits\n",
        "all_imgs = []; all_embs = []\n",
        "num_images_to_use = 512\n",
        "for d in ds_test:\n",
        "  for stack_id, img_stack in enumerate(d.take(num_images_to_use//stack_size)):\n",
        "    embs = model(img_stack, training=False)\n",
        "    if stack_id:\n",
        "      all_imgs[-1] = tf.concat([all_imgs[-1], img_stack], 0)\n",
        "      all_embs[-1] = tf.concat([all_embs[-1], embs], 0)\n",
        "    else:\n",
        "      all_imgs.append(img_stack)\n",
        "      all_embs.append(embs)\n",
        "# Group by the nearest example of each digit\n",
        "plt.figure(figsize=(10, 10))\n",
        "for digit in range(10):\n",
        "  template_img = all_imgs[digit][0]\n",
        "  for other_digit in range(10):\n",
        "    dists = pairwise_l2_distance(all_embs[digit][0:1], all_embs[other_digit])  # [1, stack_size]\n",
        "    min_ind = tf.argmin(dists[0])\n",
        "    plt.subplot(10, 10, digit*10 + other_digit + 1)\n",
        "    plt.imshow(all_imgs[other_digit][min_ind, ..., 0], cmap='binary')\n",
        "    plt.xticks([]); plt.yticks([])\n",
        "    ax = plt.gca()\n",
        "    plt.setp(ax.spines.values(), color='#ccdad1', linewidth=[0., 5.][digit==other_digit])\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e0eJZmDlqyS6"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}